-
Notifications
You must be signed in to change notification settings - Fork 41
Fix trainer detection for custom Docker images with regex pattern matching #31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix trainer detection for custom Docker images with regex pattern matching #31
Conversation
|
[APPROVALNOTIFIER] This PR is NOT APPROVED This pull-request has been approved by: The full list of commands accepted by this bot can be found here.
Needs approval from an approver in each of these files:
Approvers can indicate their approval by writing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jskswamy thank you for this contribution! And thank you for also adding tests for get_container_devices.
It looks great to me -- I just added two small suggestions.
python/tests/test_utils.py
Outdated
| # Edge cases - no match (should fall back to default) | ||
| ("unknown-image:latest", types.TrainerFramework.TORCH), | ||
| ("", types.TrainerFramework.TORCH), | ||
| ("nginx:latest", types.TrainerFramework.TORCH), | ||
| ("ubuntu:20.04", types.TrainerFramework.TORCH), | ||
| ], | ||
| ) | ||
| def test_trainer_detection_from_image_patterns( | ||
| self, image_name, expected_framework | ||
| ): | ||
| """Test trainer detection using image pattern matching with various case scenarios.""" | ||
| trainer = utils._detect_trainer_from_image_patterns(image_name) | ||
| if expected_framework == types.TrainerFramework.TORCH and trainer is None: | ||
| # For unknown images, the _detect_trainer function should return default | ||
| # but _detect_trainer_from_image_patterns returns None | ||
| assert trainer is None | ||
| else: | ||
| assert trainer is not None | ||
| assert trainer.framework.value == expected_framework.value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small suggestion to replace expected_framework with None for no-match cases. I think this make the behavior of the function being tested clearer for readers.
| # Edge cases - no match (should fall back to default) | |
| ("unknown-image:latest", types.TrainerFramework.TORCH), | |
| ("", types.TrainerFramework.TORCH), | |
| ("nginx:latest", types.TrainerFramework.TORCH), | |
| ("ubuntu:20.04", types.TrainerFramework.TORCH), | |
| ], | |
| ) | |
| def test_trainer_detection_from_image_patterns( | |
| self, image_name, expected_framework | |
| ): | |
| """Test trainer detection using image pattern matching with various case scenarios.""" | |
| trainer = utils._detect_trainer_from_image_patterns(image_name) | |
| if expected_framework == types.TrainerFramework.TORCH and trainer is None: | |
| # For unknown images, the _detect_trainer function should return default | |
| # but _detect_trainer_from_image_patterns returns None | |
| assert trainer is None | |
| else: | |
| assert trainer is not None | |
| assert trainer.framework.value == expected_framework.value | |
| # Edge cases - no match | |
| ("unknown-image:latest", None), | |
| ("", None), | |
| ("nginx:latest", None), | |
| ("ubuntu:20.04", None), | |
| ], | |
| ) | |
| def test_trainer_detection_from_image_patterns( | |
| self, image_name, expected_framework | |
| ): | |
| """Test trainer detection using image pattern matching with various case scenarios.""" | |
| trainer = utils._detect_trainer_from_image_patterns(image_name) | |
| if expected_framework is None: | |
| # For unknown images _detect_trainer_from_image_patterns returns None | |
| assert trainer is None | |
| else: | |
| assert trainer is not None | |
| assert trainer.framework.value == expected_framework.value |
| # Trainer framework constants for easy reference | ||
| class TrainerFramework(Enum): | ||
| """Trainer framework constants.""" | ||
| TORCH = "torch" | ||
| DEEPSPEED = "deepspeed" | ||
| MLX = "mlx" | ||
| TORCHTUNE = "torchtune" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is exactly the same as the Framework enum. Can we delete this and use Framework instead?
ed91f02 to
e16b9c6
Compare
|
@eoinfennessy made all the suggested changes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jskswamy LGTM! Thank you!
|
@eoinfennessy: changing LGTM is restricted to collaborators In response to this:
Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for this contribution @jskswamy 🎉
| if ml_policy.torch and ml_policy.torch.num_proc_per_node is not None: | ||
| num_proc = ml_policy.torch.num_proc_per_node.actual_instance | ||
| if isinstance(num_proc, int): | ||
| trainer.accelerator_count = num_proc | ||
| elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node: | ||
| elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to add is not None here ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1. Torch Policy Check (if trainer_container.accelerator_count is not None)
# Essential: Prevents AttributeError when accessing None.actual_instance
# Without this check: None.actual_instance would raise AttributeError
if trainer_container.accelerator_count is not None:
if hasattr(trainer_container.accelerator_count, 'actual_instance'):
trainer.accelerator_count = trainer_container.accelerator_count.actual_instance2. MPI Policy Check (if trainer_container.mpi_policy is not None)
# Essential: Prevents setting accelerator_count to None when user explicitly sets it
# Without this check: trainer.accelerator_count would be overwritten to None
if trainer_container.mpi_policy is not None:
trainer.accelerator_count = trainer_container.mpi_policy.num_procs3. Semantic Correctness
These checks ensure that:
- User-provided values are preserved and not overwritten
- We don't attempt operations on
Noneobjects - The logic follows "only apply changes if the field is actually set"
Code Comments Added:
I've added explanatory comments to each check to make their necessity clear for future maintainers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jskswamy I just meant that those 2 lines are the same in Python, isn't ?
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's a subtle but important distinction in Python.
The is not None Check is Necessary
The current code is correct because 0 is a valid and meaningful value for num_proc_per_node:
# Current correct implementation
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None:
trainer.accelerator_count = ml_policy.mpi.num_proc_per_nodeWhy Truthiness Checking Would Break CPU-Only Training
If we used truthiness checking instead:
# This would be problematic
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
trainer.accelerator_count = ml_policy.mpi.num_proc_per_nodeExample Scenarios:
Scenario 1: CPU-Only Training (0 accelerators)
ml_policy.mpi.num_proc_per_node = 0 # Explicitly set to CPU-only
# With truthiness check:
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node: # 0 is falsy!
trainer.accelerator_count = ml_policy.mpi.num_proc_per_node # ❌ Never executes
# With is not None check:
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None: # 0 is not None!
trainer.accelerator_count = ml_policy.mpi.num_proc_per_node # ✅ Executes correctlyScenario 2: GPU Training (4 accelerators)
ml_policy.mpi.num_proc_per_node = 4 # Explicitly set to 4 GPUs
# Both approaches work correctly:
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node: # 4 is truthy ✅
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None: # 4 is not None ✅Scenario 3: Not Set (defaults to UNKNOWN)
ml_policy.mpi.num_proc_per_node = None # Not explicitly set
# Both approaches work correctly:
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node: # None is falsy ✅
if ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None: # None is None ✅The Key Distinction
The is not None check properly distinguishes between:
- "Not set" (
None) → don't override accelerator count - "Explicitly set to 0" (
0) → override with 0 (CPU-only training) - "Explicitly set to positive number" → override with that number
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But why num_proc_per_node=0 is a valid value ?
We should not allow user to set such value or consider this as None.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jskswamy Did you get a chance to check this comment ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for late reply! I've addressed this, kindly check the changes now
| return None | ||
|
|
||
|
|
||
| def _detect_trainer_from_image_patterns(image_name: str) -> Optional[types.Trainer]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tenzen-y @Electronic-Waste @astefanutti @jskswamy @eoinfennessy @franciscojavierarceo Do we see any concerns with regex approach ? It might be a good and simple method to start with, but I can imagine use cases where it wouldn't work. For example, users might have two DeepSpeed runtimes:
- One uses
torchrun - Another uses
mpirun.
Perhaps in the future we can support such scenarios.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, there are use cases where image name patterns alone wouldn't be sufficient.
Current Regex Approach — Pragmatic Starting Point
The regex approach implemented serves as a practical starting point that:
- ✅ Works immediately for the majority of common use cases
- ✅ Supports all official Kubeflow trainer images out of the box
- ✅ Provides sensible defaults without requiring users to specify trainer types manually
- ✅ Maintains backward compatibility with existing workflows
Future API-Based Enhancement
For advanced scenarios like your DeepSpeed example (torchrun vs mpirun variants), we can introduce explicit API controls that override the regex detection:
# Option 1: Explicit trainer specification
trainer = Trainer(
image="custom/deepspeed-runtime",
trainer_type=TrainerType.DEEPSPEED_MPI, # Override regex detection
# ... other configs
)
# Option 2: Runtime configuration
trainer = Trainer(
image="custom/deepspeed-runtime",
runtime_config=DeepSpeedConfig(launcher="mpirun"), # vs "torchrun"
# ... other configs
)Approach
The regex approach handles ~90% of use cases elegantly, while keeping the door open for API-based precision when needed.
Question: Which approach would you prefer to proceed with?
Option A: Keep the current regex-based detection and enhance it incrementally with API overrides when needed
Option B: Move to a more explicit API-first approach where users specify trainer types directly
Option C: Hybrid approach where regex provides defaults, but API allows explicit overrides from day one
I'm happy to implement the changes that would be most valuable for users. What are your thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's keep the regex approach initially, we can have discussion how to cover more complex use-cases.
We just need to ensure we document that in our docs: https://www.kubeflow.org/docs/components/trainer/user-guides/
ec4a1c2 to
e7f4425
Compare
This commit introduces optional dependencies specifically for testing purposes. The `pytest` and `pytest-mock` packages are added to the `pyproject.toml` file under the `optional-dependencies` section, allowing developers to easily install testing tools when needed. Additionally, a new `pytest.ini` configuration section is created to standardize test settings, including options for verbosity and test discovery patterns. Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
This commit introduces a new enumeration `TrainerFramework` to centralize the definitions of various trainer frameworks used in the Kubeflow SDK. The trainer configurations have been refactored into a dictionary `TRAINER_CONFIGS`, which maps each framework to its respective configuration, reducing duplication and improving maintainability. Additionally, the trainer detection logic has been enhanced to utilize image name patterns for identifying the appropriate trainer framework based on the container image name. This improves the robustness of trainer type detection and ensures backward compatibility with the existing `ALL_TRAINERS` mapping. - Added `TrainerFramework` enum for trainer framework constants. - Refactored trainer configurations into `TRAINER_CONFIGS`. - Enhanced trainer detection logic to support image name patterns. - Added unit tests for the new detection logic and configurations. Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Updated the TrainerFramework Enum to a more generic Framework Enum to improve code maintainability and clarity. This change simplifies the trainer configurations and associated functions by using the new Framework Enum, ensuring consistent references throughout the codebase. - Replaced TrainerFramework with Framework in types.py - Updated references in utils.py to reflect the new Enum - Adjusted test cases in test_utils.py to accommodate changes Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Refactor the test cases in `test_utils.py` to adjust the expected output for edge cases where no matching framework is found. This change ensures that the tests handle cases where the image does not correspond to any known framework by returning `None` instead of a default framework. Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Move test files from tests/ directory to be co-located with source files and split types-related tests into a separate file: - tests/test_utils.py → kubeflow/trainer/utils/utils_test.py - Extract types tests → kubeflow/trainer/types/types_test.py - Update pyproject.toml testpaths: ["tests"] → ["kubeflow"] - Remove tests/ directory This improves code organization by keeping tests next to the code they validate, making it easier to maintain test coverage when modifying source files. Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Remove underscore prefixes from detect_trainer_from_image_patterns() and detect_trainer() to follow established codebase conventions. Analysis shows no other utility functions in the codebase use underscore prefixes. Functions renamed: - _detect_trainer_from_image_patterns → detect_trainer_from_image_patterns - _detect_trainer → detect_trainer Update all function calls and tests accordingly. Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Remove generic 'torch' pattern matching and require explicit 'pytorch' in image names for better framework distinction. This prevents ambiguity between PyTorch and other torch-related libraries. - Remove regex pattern: r'^torch(?!tune)' - Keep only: r'pytorch' for PyTorch detection - Update test case: 'torch-custom:latest' → 'pytorch-torch-custom:latest' - Add test case: 'torch-custom:latest' now returns None This ensures clearer separation between PyTorch and TorchTune images. Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Add detailed comments explaining why 'is not None' checks are necessary in ML policy processing: 1. For torch: prevents AttributeError when accessing None.actual_instance 2. For MPI: prevents setting accelerator_count to None 3. Semantically: only override when user explicitly provides values These checks prevent runtime errors and ensure correct behavior when ML policies have undefined num_proc_per_node values. Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Eliminate ALL_TRAINERS and rely solely on regex pattern matching for trainer detection. This removes duplication between static mapping and TRAINER_CONFIGS while maintaining full functionality. - Remove ALL_TRAINERS from types.py - Simplify detect_trainer(): regex patterns → DEFAULT_TRAINER fallback - Update tests to verify official images work with regex patterns All official Kubeflow images correctly detected by regex, ensuring no breaking changes while reducing architectural complexity. The regex patterns now serve as the single source of truth. Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
- Remove uv.lock file - Remove test dependencies from pyproject.toml - Remove pytest configuration from pyproject.toml - Keep only core trainer detection improvements and tests This ensures the PR focuses solely on trainer detection enhancements. Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
e7f4425 to
b3aed48
Compare
| TRAINER_CONFIGS: Dict[Framework, Trainer] = { | ||
| Framework.TORCH: Trainer( | ||
| trainer_type=TrainerType.CUSTOM_TRAINER, | ||
| framework=Framework.TORCH, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need to keep framework argument given that TRAINER_CONFIGS Dict has the Framework type in the Dict key.
| framework=Framework.TORCH, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regarding the framework field in the Trainer class, I'd like to share my thoughts on why this field exists and why it serves a legitimate purpose:
The framework Field Has Critical Importance
After investigating the codebase, I discovered that the Trainer class and framework field were pre-existing before this PR. The field was intentionally designed to serve specific purposes:
Critical Importance for API Design
The framework field is essential for maintaining a clean, self-contained API:
- Object Identity: A
Trainerobject must "know" what framework it represents without external context - API Completeness: When users receive a
Trainerobject, they can immediately determine its framework without reverse-engineering from other fields - Serialization: The field is crucial for JSON serialization/deserialization of trainer objects
- Debugging & Logging: Essential for meaningful error messages and debugging information
Self-Contained Data Structure
The framework field makes Trainer objects self-contained and self-documenting:
# Example: A Trainer object "knows" what framework it represents
trainer = TRAINER_CONFIGS[Framework.DEEPSPEED]
# Self-documenting: The object tells us what it is
print(f"Using {trainer.framework} trainer with {trainer.trainer_type}")
# Output: "Using Framework.DEEPSPEED trainer with TrainerType.CUSTOM_TRAINER"
# Without the field, we'd need external context to know what framework this is
# We'd have to track which dictionary key was used to create this trainerBreaking Changes Would Be Required
Removing the field would require:
- Modifying any code that relies on the field for framework identification
- Potentially breaking API consumers who expect this field
- Adding complex lookup logic to determine framework from other properties
Architectural Integrity
The field maintains the principle of encapsulation — Trainer object should contain all information about itself, including what framework it represents.
Why Dictionary Instead of Array?
The choice of using TRAINER_CONFIGS: Dict[Framework, Trainer] instead of an array of trainers was a performance and design optimization:
Performance Benefits
# Current efficient approach with dictionary
trainer = TRAINER_CONFIGS[Framework.DEEPSPEED] # O(1) lookup
framework = trainer.framework # Direct access
# Alternative inefficient approach with array
def find_trainer_by_framework(framework):
for trainer in TRAINER_ARRAY: # O(n) search
if trainer.framework == framework:
return trainerDesign Benefits
- Fast Lookup: O(1) constant time access instead of O(n) linear search
- Type Safety: Dictionary keys ensure we only access valid frameworks
- Explicit Mapping: Clear relationship between framework and trainer configuration
- Extensibility: Easy to add new frameworks without changing lookup logic
My Take
The framework field serves critical architectural purposes for API design and object encapsulation. The dictionary structure provides performance benefits, but the field itself is essential for maintaining clean, self-contained objects.
Removing the field would break the original design intent, make the API less clean and efficient, and potentially introduce breaking changes. The field was intentionally designed this way for good reasons, and I believe we should keep it to maintain the integrity of the API design.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that we should have dict to represent all Trainers where key is the Framework name and value is the Trainer object.
The question is should we also keep framework argument in the Trainer object. This is mostly used to just show users what framework this Trainer is using.
I am fine to keep it for now.
WDYT @szaher @astefanutti @Electronic-Waste ?
| trainer = detect_trainer_from_image_patterns(image_name) | ||
| if trainer: | ||
| return trainer | ||
|
|
||
| # 2. Fall back to DEFAULT_TRAINER | ||
| return copy.deepcopy(types.DEFAULT_TRAINER) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think, this could be simplified if detect_trainer_from_image_patterns just return copy.deepcopy(types.DEFAULT_TRAINER) instead of None.
Can you just keep all of the required code to extract trainer in the get_trainer_from_image() function, which accepts image_name as input?
That will make our unit tests easier to maintain.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made necessary changes to simplify the detect_trainer_from_image_patterns function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jskswamy Sorry for the late reply, I meant can you just use this code snippet ?
def get_runtime_trainer(....):
....
image_name = trainer_container.image.split(":")[0]
trainer = get_trainer_from_image(image_name)
def get_trainer_from_image(image_name: str) -> types.Trainer:
"""
Detect trainer type based on image name patterns using regex.
This method uses pattern matching on the image name to determine
the likely trainer type.
Args:
image_name: The container image name.
Returns:
Trainer: Trainer object if detected, otherwise the DEFAULT_TRAINER is returned.
"""
# DeepSpeed patterns
if re.search(r"deepspeed", image_name, re.IGNORECASE):
return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.DEEPSPEED])
# MLX patterns
if re.search(r"mlx", image_name, re.IGNORECASE):
return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.MLX])
# TorchTune patterns (check before PyTorch to avoid conflicts)
if re.search(r"torchtune", image_name, re.IGNORECASE):
return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCHTUNE])
# PyTorch patterns - require explicit "pytorch" in image name for clarity
if re.search(r"pytorch", image_name, re.IGNORECASE):
return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCH])
return copy.deepcopy(types.DEFAULT_TRAINER)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've simplified the function as per your suggestion, kindly check
…sulation - Add optional default parameter to detect_trainer_from_image_patterns() - Handle copy.deepcopy() internally for better encapsulation - Remove boilerplate code from detect_trainer() function - Add comprehensive unit tests with proper separation of concerns - Maintain backward compatibility with existing behavior Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
7f6fa23 to
d3c0043
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the late review, I think we almost ready to merge this.
@szaher @kramaranya @briangallagher @eoinfennessy Can you take a look as well please ?
| trainer = detect_trainer_from_image_patterns(image_name) | ||
| if trainer: | ||
| return trainer | ||
|
|
||
| # 2. Fall back to DEFAULT_TRAINER | ||
| return copy.deepcopy(types.DEFAULT_TRAINER) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jskswamy Sorry for the late reply, I meant can you just use this code snippet ?
def get_runtime_trainer(....):
....
image_name = trainer_container.image.split(":")[0]
trainer = get_trainer_from_image(image_name)
def get_trainer_from_image(image_name: str) -> types.Trainer:
"""
Detect trainer type based on image name patterns using regex.
This method uses pattern matching on the image name to determine
the likely trainer type.
Args:
image_name: The container image name.
Returns:
Trainer: Trainer object if detected, otherwise the DEFAULT_TRAINER is returned.
"""
# DeepSpeed patterns
if re.search(r"deepspeed", image_name, re.IGNORECASE):
return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.DEEPSPEED])
# MLX patterns
if re.search(r"mlx", image_name, re.IGNORECASE):
return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.MLX])
# TorchTune patterns (check before PyTorch to avoid conflicts)
if re.search(r"torchtune", image_name, re.IGNORECASE):
return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCHTUNE])
# PyTorch patterns - require explicit "pytorch" in image name for clarity
if re.search(r"pytorch", image_name, re.IGNORECASE):
return copy.deepcopy(types.TRAINER_CONFIGS[types.Framework.TORCH])
return copy.deepcopy(types.DEFAULT_TRAINER)| if ml_policy.torch and ml_policy.torch.num_proc_per_node is not None: | ||
| num_proc = ml_policy.torch.num_proc_per_node.actual_instance | ||
| if isinstance(num_proc, int): | ||
| trainer.accelerator_count = num_proc | ||
| elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node: | ||
| elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jskswamy Did you get a chance to check this comment ?
Simplify trainer detection API by removing optional default parameter and always returning a Trainer object. The function now directly returns DEFAULT_TRAINER when no regex patterns match, eliminating the need for None handling in calling code. Changes: - Rename function to get_trainer_from_image for clarity - Remove optional default parameter from function signature - Always return types.Trainer instead of Optional[types.Trainer] - Update all test cases to expect DEFAULT_TRAINER for unknown images - Simplify detect_trainer() function logic Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
Changes: - For torch: check actual_instance value truthiness, not just object existence - For MPI: already correctly validates the direct value - Zero values (0) are now ignored (treated as None) - Negative values are trusted as explicit user input - Update test cases to reflect new behavior Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I may not have the full context, so please pardon my naive question, why those metadata, mostly the framework type, could not come from the training runtime itself, in the form of an annotation or a label?
For custom images, as a platform admin / user, I would understand I need to provide some hints.
Previously we talked with @tenzen-y and @Electronic-Waste about introducing labels to the runtime that define framework type: https://cloud-native.slack.com/archives/C0742LDFZ4K/p1741266604716149?thread_ts=1741263570.091899&cid=C0742LDFZ4K But we decided to not add more labels since this information can be retrieved from image and APIs. |
The implicit contract based on the image name might prove fragile for users and downstream projects. I understand the regex-based heuristics could provide a nice last-resort as it stands, but making that contract explicit seems simpler and more robust. I still fail to see why it couldn't be enforced in the training runtime API, or at least the SDK would only fallback to the regex-based heuristics if that API contract is made optional. |
@astefanutti I think, we need to figure out why we even expose Runtime's trainer to the user. sdk/python/kubeflow/trainer/types/types.py Line 166 in fa5778b
Information such as:
Do we think that we can refactor some of this and remove the |
Right, that equally applies to the
For built-in trainers, it seems there is a tight coupling between the trainer and the runtime, so maybe folding things into runtime as the "source-of-truth" would be better. |
So do you mean that Runtime should tell users whether it is meant for |
Yes, one way or another. How a runtime is supposed to be used in the SDK is logically defined by the runtime, that includes the type of trainer (built-in, custom) and the framework (PyTorch, JAX, TorchTune, ...). |
|
@astefanutti Do you think that framework information is still useful for SDK users if they can always run |
No, though it'd be needed for checking the typed configuration passed by users for built-in trainers is compatible with the training runtime? |
This is correct, additionally we can't run the client.train(
runtime=Runtime(name="torchtune-llama3.2-3b")
)Also, I don't think that users needs to know about installed packages in such runtimes, since they can only modify the config (e.g. fine-tuning parameters), but not the runtime packages. Maybe for BuiltinTrainer runtime we should have two labels:
If we don't want to introduce 2nd label, we can just tell users to rely on runtime name. Thoughts @tenzen-y @astefanutti @Electronic-Waste @rudeigerc @szaher @kramaranya ? |
|
I think, we should refactor our Runtime class: https://github.com/kubeflow/sdk/blob/main/python/kubeflow/trainer/types/types.py#L176-L179 |
Yes, labels seem the most straightforward approach. There is already the |
|
Agreed that it would be better to add APIs to TrainingRuntime and ClusterTrainingRuntime to specify the framework instead of relying on image names and regex checks. Adding trainer type would also be useful. But why use labels instead of adding One idea for this that would use cross-field validation to ensure one and only one of Probably best to consider the exact APIs alongside work on kubeflow/trainer#2752. |
@eoinfennessy I agree with you it's a possible alternative. Labels are flexible and enable listing runtimes by label selectors, but we could conceptually consider these metadata as part of the spec. |
Ah, I hadn't considered that. Yes, that could help improve the UX of the |
I am not sure if we should continue to maintain this label. IIRC, @tenzen-y has concerns introducing this label in the runtimes.
We can also use field selector, if we introduce a new API in the runtime. What are the pros and cons to add this property under labels or APIs ? |
You're right, it may not be a good example "semantically".
I'm not sure custom fields from CRDs are indexed. It might be only few fields from core APIs.
I would say labels are more "free-form" and not as strictly part of the API contract compared to fields. |
You are right @astefanutti, here is the list of supported fields: https://kubernetes.io/docs/concepts/overview/working-with-objects/field-selectors/#list-of-supported-fields
Let me remove this label for now, unless we design better way to explain users the accelerator types in the Runtime. @Electronic-Waste @astefanutti @tenzen-y Any concerns to introduce these three labels to our runtimes for now ? trainer.kubeflow.org/trainer-type: custom
or
trainer.kubeflow.org/trainer-type: builtin
trainer.kubeflow.org/builtin-config: torchtuneAlternatively, we can introduce |
I think that's a good start. I only wonder whether those should be within the
Actually |
I think, if we keep the
@astefanutti If we establish the contract that builtin configs contain framework name in the DataClass name, the |
That would be good yes. Having a mapping between framework name and DataClass name in the SDK would be perfectly acceptable I think. |
Problem
Custom DeepSpeed Docker images were losing the
mpiruncommand and falling back totorchruninstead. Theget_runtime_trainerfunction only used a hardcodedALL_TRAINERSmapping, so any custom images not in this mapping would default to the PyTorch trainer configuration.Solution
Enhanced trainer detection with regex-based pattern matching as a fallback mechanism:
ALL_TRAINERSmapping for exact matches(?i)deepspeed(case-insensitive)(?i)mlx(?i)torchtune(?i)(pytorch|torch)(but not torchtune)Key Changes
_detect_trainer_from_image_patterns()function with case-insensitive regex matching_detect_trainer()to use pattern matching as fallbackcopy.deepcopy()to prevent shared state issues between trainer configurationsTesting
ALL_TRAINERSmappingThis ensures custom DeepSpeed images like
my-org/deepspeed-custom:latestcorrectly usempiruninstead of falling back totorchrun.This fixes the issue #29